{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7680\n",
      "1640\n",
      "1660\n"
     ]
    }
   ],
   "source": [
    "# Dataset definition\n",
    "class SurnameDataset(Dataset):\n",
    "    '''One split of the digitized surnames CSV.\n",
    "\n",
    "    Args:\n",
    "        part: value of the `part` column that selects the rows to keep\n",
    "            (the notebook uses 'train', 'val' and 'test').\n",
    "        path: CSV file to read; defaults to the notebook's data file.\n",
    "\n",
    "    Indexing returns the pair (column 0, column 1) of the selected row.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, part, path='./data/surnames/数字化数据.csv'):\n",
    "        data = pd.read_csv(path)\n",
    "        # Keep only the rows that belong to the requested split.\n",
    "        self.data = data[data.part == part]\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        # Positional access, so the filtered frame's original index labels\n",
    "        # do not matter.\n",
    "        return self.data.iloc[i, 0], self.data.iloc[i, 1]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "\n",
    "# Build one dataset per split (train, then val, then test) and print\n",
    "# how many rows each one holds.\n",
    "train_dataset, val_dataset, test_dataset = (\n",
    "    SurnameDataset(part='train'),\n",
    "    SurnameDataset(part='val'),\n",
    "    SurnameDataset(part='test'),\n",
    ")\n",
    "\n",
    "for split_dataset in (train_dataset, val_dataset, test_dataset):\n",
    "    print(len(split_dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],\n",
      "\n",
      "        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "         [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]) torch.Size([100, 15, 29])\n",
      "tensor([17, 15, 12, 17,  5]) torch.Size([100])\n"
     ]
    }
   ],
   "source": [
    "# Convert a batch of (x, y) samples into one-hot tensors\n",
    "def one_hot(data, max_len=15, vocab_size=29):\n",
    "    '''Collate (x, y) samples into a one-hot input tensor and a label tensor.\n",
    "\n",
    "    Args:\n",
    "        data: indexable collection of (x, y) pairs; x is a comma-separated\n",
    "            string of integer token ids and y is an integer class label.\n",
    "        max_len: number of token positions kept per sample; longer inputs\n",
    "            are truncated and shorter ones stay zero-padded.\n",
    "        vocab_size: width of each one-hot vector.\n",
    "\n",
    "    Returns:\n",
    "        (xs, ys): a float32 tensor of shape [N, max_len, vocab_size] and an\n",
    "        int64 tensor of shape [N].\n",
    "    '''\n",
    "    N = len(data)\n",
    "    # Allocate with the final dtypes up front: labels no longer pass through\n",
    "    # a float64 buffer before being turned into int64.\n",
    "    xs = np.zeros((N, max_len, vocab_size), dtype=np.float32)\n",
    "    ys = np.empty(N, dtype=np.int64)\n",
    "    for i in range(N):\n",
    "        x, y = data[i]\n",
    "        ys[i] = y\n",
    "\n",
    "        tokens = x.split(',')[:max_len]\n",
    "        for j, token in enumerate(tokens):\n",
    "            # The id is shifted down by one to get the one-hot column\n",
    "            # (so ids appear to be 1-based).\n",
    "            xs[i, j, int(token) - 1] = 1\n",
    "\n",
    "    return torch.from_numpy(xs), torch.from_numpy(ys)\n",
    "\n",
    "\n",
    "# Data loaders\n",
    "BATCH_SIZE = 100\n",
    "\n",
    "# Training: reshuffle every epoch and drop the ragged last batch so every\n",
    "# optimisation step sees a full batch.\n",
    "train_dataloader = DataLoader(dataset=train_dataset,\n",
    "                              batch_size=BATCH_SIZE,\n",
    "                              shuffle=True,\n",
    "                              drop_last=True,\n",
    "                              collate_fn=one_hot)\n",
    "\n",
    "# Evaluation: keep every sample (no drop_last) in a fixed order so the\n",
    "# reported accuracy covers the whole split and is repeatable.\n",
    "val_dataloader = DataLoader(dataset=val_dataset,\n",
    "                            batch_size=BATCH_SIZE,\n",
    "                            shuffle=False,\n",
    "                            drop_last=False,\n",
    "                            collate_fn=one_hot)\n",
    "\n",
    "test_dataloader = DataLoader(dataset=test_dataset,\n",
    "                             batch_size=BATCH_SIZE,\n",
    "                             shuffle=False,\n",
    "                             drop_last=False,\n",
    "                             collate_fn=one_hot)\n",
    "\n",
    "# Peek at the first training batch\n",
    "for x, y in train_dataloader:\n",
    "    print(x[:2, :2], x.shape)\n",
    "    print(y[:5], y.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0775,  0.1844, -0.1364, -0.0393,  0.0851, -0.0010,  0.1305, -0.0033,\n",
       "          0.1114, -0.0167, -0.2325, -0.1800, -0.0974, -0.0969,  0.0864, -0.0837,\n",
       "         -0.1012,  0.2431],\n",
       "        [ 0.1589,  0.1636, -0.0502, -0.1285,  0.0600, -0.0434, -0.0683, -0.0279,\n",
       "          0.0972,  0.0287, -0.2299, -0.0088,  0.0223,  0.0191, -0.0168, -0.0585,\n",
       "         -0.0709,  0.1587]], grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Network definition\n",
    "class SurnameClassifier(nn.Module):\n",
    "    '''1-D convolutional classifier over one-hot encoded surnames.\n",
    "\n",
    "    The input is [b, in_channels, 29]: token positions act as the Conv1d\n",
    "    channels and the convolutions slide along the one-hot axis. Three\n",
    "    kernel-size-5 convolutions shrink that axis 29 -> 13 -> 5 -> 1, and a\n",
    "    linear layer maps the remaining [b, hidden] features to class logits.\n",
    "\n",
    "    Args:\n",
    "        hidden: number of channels produced by every convolution.\n",
    "        num_classes: number of output logits.\n",
    "        in_channels: channel count of the input (token positions per sample).\n",
    "    '''\n",
    "\n",
    "    def __init__(self, hidden=50, num_classes=18, in_channels=15):\n",
    "        super(SurnameClassifier, self).__init__()\n",
    "\n",
    "        # [b,in_channels,29] -> [b,hidden,13]\n",
    "        self.conv1 = nn.Conv1d(in_channels=in_channels,\n",
    "                               out_channels=hidden,\n",
    "                               kernel_size=5,\n",
    "                               stride=2)\n",
    "\n",
    "        # [b,hidden,13] -> [b,hidden,5]\n",
    "        self.conv2 = nn.Conv1d(in_channels=hidden,\n",
    "                               out_channels=hidden,\n",
    "                               kernel_size=5,\n",
    "                               stride=2)\n",
    "\n",
    "        # [b,hidden,5] -> [b,hidden,1]\n",
    "        self.conv3 = nn.Conv1d(in_channels=hidden,\n",
    "                               out_channels=hidden,\n",
    "                               kernel_size=5,\n",
    "                               stride=1)\n",
    "\n",
    "        # Activation function (stateless, so one instance is shared)\n",
    "        self.elu = nn.ELU()\n",
    "\n",
    "        # The layers stay registered both as attributes and inside the\n",
    "        # Sequential so existing state_dict keys remain unchanged.\n",
    "        self.convnet = nn.Sequential(self.conv1, self.elu, self.conv2,\n",
    "                                     self.elu, self.conv3, self.elu)\n",
    "\n",
    "        self.fc = nn.Linear(hidden, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        '''Map a [b, in_channels, 29] batch to [b, num_classes] logits.'''\n",
    "        # [b,in_channels,29] -> [b,hidden,1] -> [b,hidden]\n",
    "        out = self.convnet(x).squeeze(dim=2)\n",
    "\n",
    "        # [b,hidden] -> [b,num_classes]\n",
    "        out = self.fc(out)\n",
    "        return out\n",
    "\n",
    "\n",
    "# Smoke test: push a random batch of two samples through a fresh model\n",
    "model = SurnameClassifier()\n",
    "sample_batch = torch.randn(2, 15, 29)\n",
    "model(sample_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.025625"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def test(dataloader, net=None):\n",
    "    '''Return the classification accuracy of a model over a data loader.\n",
    "\n",
    "    Args:\n",
    "        dataloader: iterable of (x, y) batches; x is fed to the model and\n",
    "            y holds the integer class labels.\n",
    "        net: model to evaluate; defaults to the notebook-level `model`.\n",
    "\n",
    "    Returns:\n",
    "        Fraction of correctly classified samples, or 0.0 when the loader\n",
    "        yields no samples.\n",
    "    '''\n",
    "    if net is None:\n",
    "        net = model\n",
    "\n",
    "    # Remember the current mode: leaving the model in eval mode would make\n",
    "    # a surrounding training loop silently continue in eval mode.\n",
    "    was_training = net.training\n",
    "    net.eval()\n",
    "\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    try:\n",
    "        # Scoring needs no gradients, so skip building the autograd graph.\n",
    "        with torch.no_grad():\n",
    "            for x, y in dataloader:\n",
    "                y_pred = net(x).argmax(dim=1)\n",
    "\n",
    "                correct += (y_pred == y).sum().item()\n",
    "                total += len(y)\n",
    "    finally:\n",
    "        net.train(was_training)\n",
    "\n",
    "    # An empty loader has no accuracy to speak of; report 0.0 instead of\n",
    "    # raising ZeroDivisionError.\n",
    "    return correct / total if total else 0.0\n",
    "\n",
    "\n",
    "# Validation accuracy at this point in the notebook (the training loop\n",
    "# lives in a later cell)\n",
    "test(val_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 2.205754041671753 0.30375\n",
      "1 1.685968041419983 0.529375\n",
      "2 1.3684632778167725 0.5575\n",
      "3 1.2239166498184204 0.579375\n",
      "4 1.2784945964813232 0.595\n",
      "5 1.066778540611267 0.606875\n",
      "6 1.2280161380767822 0.635\n",
      "7 1.259406566619873 0.638125\n",
      "8 1.1004050970077515 0.6575\n",
      "9 1.0219842195510864 0.645\n"
     ]
    }
   ],
   "source": [
    "loss_func = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "\n",
    "for epoch in range(10):\n",
    "    # Re-enter training mode at the start of every epoch: test() may leave\n",
    "    # the model in eval mode, so a single call before the loop would only\n",
    "    # cover the first epoch.\n",
    "    model.train()\n",
    "    for x, y in train_dataloader:\n",
    "        optimizer.zero_grad()\n",
    "        y_pred = model(x)\n",
    "\n",
    "        loss = loss_func(y_pred, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    # Report the last batch's training loss and the validation accuracy.\n",
    "    accuracy = test(val_dataloader)\n",
    "    print(epoch, loss.item(), accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.659375"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Accuracy on the test split\n",
    "test(test_dataloader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  },
  "toc": {
   "colors": {
    "hover_highlight": "#DAA520",
    "running_highlight": "#FF0000",
    "selected_highlight": "#FFD700"
   },
   "moveMenuLeft": true,
   "nav_menu": {
    "height": "138px",
    "width": "252px"
   },
   "navigate_menu": true,
   "number_sections": true,
   "sideBar": true,
   "threshold": "5",
   "toc_cell": false,
   "toc_section_display": "block",
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
